from math import isclose
from pathlib import Path
from typing import cast

import numpy as np
import pytest
import torch
import torch.nn as nn
from PIL import Image
from tests.foundationals.segment_anything.utils import (
    FacebookSAM,
    FacebookSAMPredictor,
    SAMPrompt,
    intersection_over_union,
)
from torch import Tensor

import refiners.fluxion.layers as fl
from refiners.conversion.model_converter import ModelConverter
from refiners.fluxion import manual_seed
from refiners.fluxion.utils import image_to_tensor, load_tensors, no_grad
from refiners.foundationals.segment_anything.image_encoder import FusedSelfAttention, RelativePositionAttention
from refiners.foundationals.segment_anything.mask_decoder import MaskDecoder
from refiners.foundationals.segment_anything.model import ImageEmbedding, SegmentAnythingH
from refiners.foundationals.segment_anything.transformer import TwoWayTransformerLayer

# See the official predictor_example.ipynb notebook
PROMPTS: list[SAMPrompt] = [
    SAMPrompt(foreground_points=((500, 375),)),
    SAMPrompt(background_points=((500, 375),)),
    SAMPrompt(foreground_points=((500, 375), (1125, 625))),
    SAMPrompt(foreground_points=((500, 375),), background_points=((1125, 625),)),
    SAMPrompt(box_points=[[(425, 600), (700, 875)]]),
    SAMPrompt(box_points=[[(425, 600), (700, 875)]], background_points=((575, 750),)),
]
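# For reference, the official predictor consumes prompts as numpy kwargs; presumably
# `facebook_predict_kwargs` maps e.g. the first prompt above to something like
# point_coords=np.array([[500, 375]]) with point_labels=np.array([1]) (1 = foreground, 0 = background).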


@pytest.fixture(params=PROMPTS)
def prompt(request: pytest.FixtureRequest) -> SAMPrompt:
    return request.param


@pytest.fixture
def one_prompt() -> SAMPrompt:
    # Use the third prompt of the PROMPTS list so that test_predictor_dense_mask runs
    # strictly the same scenario as the official notebook.
    return PROMPTS[2]


@pytest.fixture(scope="module")
def facebook_sam_h(sam_h_unconverted_weights_path: Path, test_device: torch.device) -> FacebookSAM:
    from segment_anything import build_sam_vit_h  # type: ignore

    sam_h = cast(FacebookSAM, build_sam_vit_h())
    sam_h.load_state_dict(state_dict=load_tensors(sam_h_unconverted_weights_path))
    return sam_h.to(device=test_device)


@pytest.fixture(scope="module")
def facebook_sam_h_predictor(facebook_sam_h: FacebookSAM) -> FacebookSAMPredictor:
    from segment_anything import SamPredictor  # type: ignore
    from segment_anything.modeling import Sam  # type: ignore

    predictor = SamPredictor(cast(Sam, facebook_sam_h))  # type: ignore
    return cast(FacebookSAMPredictor, predictor)


@pytest.fixture(scope="module")
def sam_h(sam_h_weights_path: Path, test_device: torch.device) -> SegmentAnythingH:
    sam_h = SegmentAnythingH(device=test_device)
    sam_h.load_from_safetensors(tensors_path=sam_h_weights_path)
    return sam_h


@pytest.fixture(scope="module")
def sam_h_single_output(sam_h_weights_path: Path, test_device: torch.device) -> SegmentAnythingH:
    sam_h = SegmentAnythingH(multimask_output=False, device=test_device)
    sam_h.load_from_safetensors(tensors_path=sam_h_weights_path)
    return sam_h


@pytest.fixture(scope="module")
def truck(ref_path: Path) -> Image.Image:
    return Image.open(ref_path / "truck.jpg").convert("RGB")  # type: ignore


@no_grad()
def test_fused_self_attention(facebook_sam_h: FacebookSAM) -> None:
    manual_seed(seed=0)
    x = torch.randn(25, 14, 14, 1280, device=facebook_sam_h.device)
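    # This shape matches the input of windowed attention in the ViT-H encoder: 25 windows
    # (a 5x5 grid over the 64x64 token grid padded to 70x70) of 14x14 tokens, 1280 channels each.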

    attention = cast(nn.Module, facebook_sam_h.image_encoder.blocks[0].attn)

    refiners_attention = FusedSelfAttention(
        embedding_dim=1280, num_heads=16, spatial_size=(14, 14), device=facebook_sam_h.device
    )

    rpa = refiners_attention.layer("RelativePositionAttention", RelativePositionAttention)
    linear_1 = refiners_attention.layer("Linear_1", fl.Linear)
    linear_2 = refiners_attention.layer("Linear_2", fl.Linear)

    linear_1.weight = attention.qkv.weight
    linear_1.bias = attention.qkv.bias
    linear_2.weight = attention.proj.weight
    linear_2.bias = attention.proj.bias
    rpa.horizontal_embedding = attention.rel_pos_w
    rpa.vertical_embedding = attention.rel_pos_h
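    # With the weights ported above (qkv -> Linear_1, proj -> Linear_2, relative position
    # embeddings -> RelativePositionAttention), the two implementations should match bit-for-bit.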

    y_1 = attention(x)
    assert y_1.shape == x.shape

    y_2 = refiners_attention(x)
    assert y_2.shape == x.shape

    assert torch.equal(input=y_1, other=y_2)


def test_mask_decoder_arg() -> None:
    mask_decoder_default = MaskDecoder()
    sam_h = SegmentAnythingH(mask_decoder=mask_decoder_default)

    assert sam_h.mask_decoder == mask_decoder_default


def test_multimask_output_error() -> None:
    mask_decoder_multimask_output = MaskDecoder(multimask_output=True)
    with pytest.raises(AssertionError, match="multimask_output"):
        SegmentAnythingH(mask_decoder=mask_decoder_multimask_output, multimask_output=False)


@no_grad()
def test_image_encoder(sam_h: SegmentAnythingH, facebook_sam_h: FacebookSAM, truck: Image.Image) -> None:
    resized = truck.resize(size=(1024, 1024))  # type: ignore
    image_tensor = image_to_tensor(image=resized, device=facebook_sam_h.device)
    y_1 = facebook_sam_h.image_encoder(image_tensor)
    y_2 = sam_h.image_encoder(image_tensor)

    assert torch.allclose(input=y_1, other=y_2, atol=1e-4)


@no_grad()
def test_prompt_encoder_dense_positional_embedding(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH) -> None:
    facebook_prompt_encoder = facebook_sam_h.prompt_encoder
    refiners_prompt_encoder = sam_h.point_encoder

    facebook_dense_pe: Tensor = cast(Tensor, facebook_prompt_encoder.get_dense_pe())  # type: ignore
    refiners_dense_pe = refiners_prompt_encoder.get_dense_positional_embedding(image_embedding_size=(64, 64))
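    # (64, 64) is the image embedding size: 1024 input resolution / 16 ViT patch size.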

    assert torch.equal(input=refiners_dense_pe, other=facebook_dense_pe)


@no_grad()
def test_prompt_encoder_no_mask_dense_embedding(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH) -> None:
    facebook_prompt_encoder = facebook_sam_h.prompt_encoder
    refiners_prompt_encoder = sam_h.mask_encoder

    _, facebook_dense_pe = facebook_prompt_encoder(points=None, boxes=None, masks=None)
    refiners_dense_pe = refiners_prompt_encoder.get_no_mask_dense_embedding(image_embedding_size=(64, 64))

    assert torch.equal(input=refiners_dense_pe, other=facebook_dense_pe)


@no_grad()
def test_point_encoder(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH, prompt: SAMPrompt) -> None:
    facebook_prompt_encoder = facebook_sam_h.prompt_encoder
    refiners_prompt_encoder = sam_h.point_encoder

    facebook_sparse_pe, _ = facebook_prompt_encoder(
        **prompt.facebook_prompt_encoder_kwargs(device=facebook_sam_h.device)
    )

    prompt_dict = dict(prompt.__dict__)  # copy, so the pop below does not mutate the shared fixture
    # Skip the mask prompt, if any, since the point encoder only consumes points and boxes
    # TODO: split `SAMPrompt` and introduce a dedicated one for dense prompts
    prompt_dict.pop("low_res_mask", None)

    assert any(value is not None for value in prompt_dict.values()), "`test_point_encoder` cannot be called with only a `low_res_mask`"

    coordinates, type_mask = refiners_prompt_encoder.points_to_tensor(**prompt_dict)
    # Shift to the center of the pixel, then normalize to [0, 1] (see `_embed_points` in the official segment-anything repo)
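    # (1024 is the side length of SAM's fixed square input resolution)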
    coordinates[:, :, 0] = (coordinates[:, :, 0] + 0.5) / 1024.0
    coordinates[:, :, 1] = (coordinates[:, :, 1] + 0.5) / 1024.0
    refiners_prompt_encoder.set_type_mask(type_mask=type_mask)
    refiners_sparse_pe = refiners_prompt_encoder(coordinates)

    assert torch.equal(input=refiners_sparse_pe, other=facebook_sparse_pe)


@no_grad()
def test_two_way_transformer(facebook_sam_h: FacebookSAM) -> None:
    dense_embedding = torch.randn(1, 64 * 64, 256, device=facebook_sam_h.device)
    dense_positional_embedding = torch.randn(1, 64 * 64, 256, device=facebook_sam_h.device)
    sparse_embedding = torch.randn(1, 3, 256, device=facebook_sam_h.device)

    refiners_layer = TwoWayTransformerLayer(
        embedding_dim=256, feed_forward_dim=2048, num_heads=8, device=facebook_sam_h.device
    )
    facebook_layer = facebook_sam_h.mask_decoder.transformer.layers[1]  # type: ignore
    assert isinstance(facebook_layer, nn.Module)

    refiners_layer.set_context(
        context="mask_decoder",
        value={
            "dense_embedding": dense_embedding,
            "dense_positional_embedding": dense_positional_embedding,
            "sparse_embedding": sparse_embedding,
        },
    )
    facebook_inputs = {
        "queries": sparse_embedding,
        "keys": dense_embedding,
        "query_pe": sparse_embedding,
        "key_pe": dense_positional_embedding,
    }

    converter = ModelConverter(
        source_model=facebook_layer,
        target_model=refiners_layer,
        skip_output_check=True,  # done below, manually
    )

    assert converter.run(source_args=facebook_inputs, target_args=(sparse_embedding,))

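    # `converter.run` above already executed the target model, so re-set the context
    # before the manual output comparison below.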
    refiners_layer.set_context(
        context="mask_decoder",
        value={
            "dense_embedding": dense_embedding,
            "dense_positional_embedding": dense_positional_embedding,
            "sparse_embedding": sparse_embedding,
        },
    )
    y_1 = facebook_layer(**facebook_inputs)[0]
    y_2 = refiners_layer(sparse_embedding)[0]

    assert torch.equal(input=y_1, other=y_2)


@no_grad()
def test_mask_decoder(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH) -> None:
    manual_seed(seed=0)
    facebook_mask_decoder = facebook_sam_h.mask_decoder
    refiners_mask_decoder = sam_h.mask_decoder

    image_embedding = torch.randn(1, 256, 64, 64, device=facebook_sam_h.device)
    dense_positional_embedding = torch.randn(1, 256, 64, 64, device=facebook_sam_h.device)
    point_embedding = torch.randn(1, 3, 256, device=facebook_sam_h.device)
    mask_embedding = torch.randn(1, 256, 64, 64, device=facebook_sam_h.device)

    from segment_anything.modeling.common import LayerNorm2d  # type: ignore

    assert issubclass(LayerNorm2d, nn.Module)
    custom_layers = {LayerNorm2d: fl.LayerNorm2d}

    converter = ModelConverter(
        source_model=facebook_mask_decoder,
        target_model=refiners_mask_decoder,
        custom_layer_mapping=custom_layers,  # type: ignore
    )

    inputs = {
        "image_embeddings": image_embedding,
        "image_pe": dense_positional_embedding,
        "sparse_prompt_embeddings": point_embedding,
        "dense_prompt_embeddings": mask_embedding,
        "multimask_output": True,
    }

    refiners_mask_decoder.set_image_embedding(image_embedding)
    refiners_mask_decoder.set_point_embedding(point_embedding)
    refiners_mask_decoder.set_mask_embedding(mask_embedding)
    refiners_mask_decoder.set_dense_positional_embedding(dense_positional_embedding)

    mapping = converter.map_state_dicts(source_args=inputs, target_args={})
    assert mapping is not None
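    # The refiners mask decoder fuses the IoU token and the mask tokens into a single
    # Parameter; map it manually here, then overwrite the converted weight with the
    # concatenation of both facebook embeddings below.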
    mapping["MaskDecoderTokens.Parameter"] = "iou_token"

    state_dict = converter._convert_state_dict(  # type: ignore
        source_state_dict=facebook_mask_decoder.state_dict(),
        target_state_dict=refiners_mask_decoder.state_dict(),
        state_dict_mapping=mapping,
    )
    state_dict["MaskDecoderTokens.Parameter.weight"] = torch.cat(
        [facebook_mask_decoder.iou_token.weight, facebook_mask_decoder.mask_tokens.weight], dim=0
    )  # type: ignore
    refiners_mask_decoder.load_state_dict(state_dict=state_dict)

    facebook_output = facebook_mask_decoder(**inputs)

    refiners_mask_decoder.set_image_embedding(image_embedding)
    refiners_mask_decoder.set_point_embedding(point_embedding)
    refiners_mask_decoder.set_mask_embedding(mask_embedding)
    refiners_mask_decoder.set_dense_positional_embedding(dense_positional_embedding)
    mask_prediction, iou_prediction = refiners_mask_decoder()

    facebook_masks = facebook_output[0]
    facebook_prediction = facebook_output[1]

    assert torch.equal(input=mask_prediction, other=facebook_masks)
    assert torch.equal(input=iou_prediction, other=facebook_prediction)


def test_predictor(
    facebook_sam_h_predictor: FacebookSAMPredictor, sam_h: SegmentAnythingH, truck: Image.Image, prompt: SAMPrompt
) -> None:
    predictor = facebook_sam_h_predictor
    predictor.set_image(np.array(truck))
    facebook_masks, facebook_scores, _ = predictor.predict(**prompt.facebook_predict_kwargs())  # type: ignore

    assert len(facebook_masks) == 3

    masks, scores, _ = sam_h.predict(truck, **prompt.__dict__)
    masks = masks.squeeze(0)
    scores = scores.squeeze(0)

    assert len(masks) == 3

    for i in range(3):
        mask_prediction = masks[i].cpu()
        facebook_mask = torch.as_tensor(facebook_masks[i])
        iou = intersection_over_union(mask_prediction, facebook_mask)
        assert isclose(iou, 1.0, rel_tol=5e-04), f"iou: {iou}"
        assert isclose(scores[i].item(), facebook_scores[i].item(), rel_tol=1e-04)


def test_predictor_image_embedding(sam_h: SegmentAnythingH, truck: Image.Image, one_prompt: SAMPrompt) -> None:
    masks_ref, scores_ref, _ = sam_h.predict(truck, **one_prompt.__dict__)

    image_embedding = sam_h.compute_image_embedding(truck)
    masks, scores, _ = sam_h.predict(image_embedding, **one_prompt.__dict__)

    assert torch.equal(masks, masks_ref)
    assert torch.equal(scores_ref, scores)


def test_predictor_dense_mask(
    facebook_sam_h_predictor: FacebookSAMPredictor, sam_h: SegmentAnythingH, truck: Image.Image, one_prompt: SAMPrompt
) -> None:
    """
    NOTE : Binarizing intermediate masks isn't necessary, as per SamPredictor.predict_torch docstring:
    > mask_input (np.ndarray): A low resolution mask input to the model, typically
    >         coming from a previous prediction iteration. Has form Bx1xHxW, where
    >         for SAM, H=W=256. Masks returned by a previous iteration of the
    >         predict method do not need further transformation.
    """
    predictor = facebook_sam_h_predictor
    predictor.set_image(np.array(truck))
    facebook_masks, facebook_scores, facebook_logits = predictor.predict(
        **one_prompt.facebook_predict_kwargs(),  # type: ignore
        multimask_output=True,
    )

    assert len(facebook_masks) == 3

    facebook_mask_input = facebook_logits[np.argmax(facebook_scores)]  # shape: HxW
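    # As in the official notebook, the best-scoring low-res logits are fed back as a dense prompt.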

    # Use the same point coordinates and mask input as the official notebook
    facebook_prompt = SAMPrompt(
        foreground_points=((500, 375),), background_points=((1125, 625),), low_res_mask=facebook_mask_input[None, ...]
    )
    facebook_dense_masks, _, _ = predictor.predict(**facebook_prompt.facebook_predict_kwargs(), multimask_output=True)  # type: ignore

    assert len(facebook_dense_masks) == 3

    masks, scores, logits = sam_h.predict(truck, **one_prompt.__dict__)
    masks = masks.squeeze(0)
    scores = scores.squeeze(0)

    assert len(masks) == 3

    mask_input = logits[:, scores.max(dim=0).indices, ...]  # shape: 1xHxW

    assert np.allclose(
        mask_input.cpu(), facebook_mask_input, atol=1e-1
    )  # Lower doesn't pass, but it's close enough for logits

    refiners_prompt = SAMPrompt(
        foreground_points=((500, 375),), background_points=((1125, 625),), low_res_mask=mask_input.unsqueeze(0)
    )
    dense_masks, _, _ = sam_h.predict(truck, **refiners_prompt.__dict__)
    dense_masks = dense_masks.squeeze(0)

    assert len(dense_masks) == 3

    for i in range(3):
        dense_mask_prediction = dense_masks[i].cpu()
        facebook_dense_mask = torch.as_tensor(facebook_dense_masks[i])
        assert dense_mask_prediction.shape == facebook_dense_mask.shape
        assert isclose(intersection_over_union(dense_mask_prediction, facebook_dense_mask), 1.0, rel_tol=5e-05)


def test_predictor_single_output(
    facebook_sam_h_predictor: FacebookSAMPredictor,
    sam_h_single_output: SegmentAnythingH,
    truck: Image.Image,
    one_prompt: SAMPrompt,
) -> None:
    predictor = facebook_sam_h_predictor
    predictor.set_image(np.array(truck))

    facebook_masks, facebook_scores, facebook_low_res_masks = predictor.predict(  # type: ignore
        **one_prompt.facebook_predict_kwargs(),  # type: ignore
        multimask_output=False,
    )

    assert len(facebook_masks) == 1

    masks, scores, low_res_masks = sam_h_single_output.predict(truck, **one_prompt.__dict__)
    masks = masks.squeeze(0)
    scores = scores.squeeze(0)

    assert len(masks) == 1

    assert torch.allclose(
        low_res_masks[0, 0, ...],
        torch.as_tensor(facebook_low_res_masks[0], device=sam_h_single_output.device),
        atol=5e-2,  # see test_predictor_resized_single_output for more explanation
    )
    assert isclose(scores[0].item(), facebook_scores[0].item(), abs_tol=1e-05)

    mask_prediction = masks[0].cpu()
    facebook_mask = torch.as_tensor(facebook_masks[0])
    assert isclose(intersection_over_union(mask_prediction, facebook_mask), 1.0, rel_tol=5e-05)


def test_predictor_resized_single_output(
    facebook_sam_h_predictor: FacebookSAMPredictor,
    sam_h_single_output: SegmentAnythingH,
    truck: Image.Image,
    one_prompt: SAMPrompt,
) -> None:
    # The refiners implementation of SAM differs from the official implementation
    # by an absolute diff of about 6e-3 (see test_predictor_single_output).
    # This diff comes from two components:
    # * the image encoder (see test_image_encoder)
    # * point rescaling (facebook uses numpy while refiners uses torch)
    #
    # This test is designed to work around both components:
    # * the facebook image embedding is reused
    # * the image is pre-resized to (1024, 1024), so no point rescaling happens
    # The test can then assert with torch.equal.

    predictor = facebook_sam_h_predictor
    size = (1024, 1024)
    resized_truck = truck.resize(size)  # type: ignore
    predictor.set_image(np.array(resized_truck))

    _, _, facebook_low_res_masks = predictor.predict(  # type: ignore
        **one_prompt.facebook_predict_kwargs(),  # type: ignore
        multimask_output=False,
    )

    facebook_image_embedding = ImageEmbedding(features=predictor.features, original_image_size=size)
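    # Wrapping the precomputed features in `ImageEmbedding` lets `predict` skip the
    # refiners image encoder entirely and consume the facebook embedding as-is.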

    _, _, low_res_masks = sam_h_single_output.predict(facebook_image_embedding, **one_prompt.__dict__)

    assert torch.equal(
        low_res_masks[0, 0, ...],
        torch.as_tensor(facebook_low_res_masks[0], device=sam_h_single_output.device),
    )


def test_mask_encoder(
    facebook_sam_h_predictor: FacebookSAMPredictor,
    sam_h: SegmentAnythingH,
    truck: Image.Image,
    one_prompt: SAMPrompt,
) -> None:
    predictor = facebook_sam_h_predictor
    predictor.set_image(np.array(truck))
    _, facebook_scores, facebook_logits = predictor.predict(
        **one_prompt.facebook_predict_kwargs(),  # type: ignore
        multimask_output=True,
    )
    facebook_mask_input = facebook_logits[np.argmax(facebook_scores)]
    facebook_mask_input = (
        torch.from_numpy(facebook_mask_input)  # type: ignore
        .to(device=predictor.model.device)
        .unsqueeze(0)
        .unsqueeze(0)  # shape: 1x1xHxW
    )

    _, fb_dense_embeddings = predictor.model.prompt_encoder(
        points=None,
        boxes=None,
        masks=facebook_mask_input,
    )

    _, scores, logits = sam_h.predict(truck, **one_prompt.__dict__)
    scores = scores.squeeze(0)
    mask_input = logits[:, scores.max(dim=0).indices, ...].unsqueeze(0)  # shape: 1x1xHxW
    dense_embeddings = sam_h.mask_encoder(mask_input)

    assert facebook_mask_input.shape == mask_input.shape
    assert torch.allclose(dense_embeddings, fb_dense_embeddings, atol=1e-3)


@no_grad()
def test_batch_mask_decoder(sam_h: SegmentAnythingH) -> None:
    batch_size = 5

    image_embedding = torch.randn(1, 256, 64, 64, device=sam_h.device, dtype=sam_h.dtype).repeat(batch_size, 1, 1, 1)
    mask_embedding = torch.randn(1, 256, 64, 64, device=sam_h.device, dtype=sam_h.dtype).repeat(batch_size, 1, 1, 1)
    dense_positional_embedding = torch.randn(1, 256, 64, 64, device=sam_h.device, dtype=sam_h.dtype).repeat(
        batch_size, 1, 1, 1
    )
    point_embedding = torch.randn(1, 2, 256, device=sam_h.device, dtype=sam_h.dtype).repeat(batch_size, 1, 1)
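    # Every batch item is a repeat of the same inputs, so all outputs must be identical
    # across the batch (checked at the end of the test).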

    sam_h.mask_decoder.set_image_embedding(image_embedding)
    sam_h.mask_decoder.set_mask_embedding(mask_embedding)
    sam_h.mask_decoder.set_point_embedding(point_embedding)
    sam_h.mask_decoder.set_dense_positional_embedding(dense_positional_embedding)

    mask_prediction, iou_prediction = sam_h.mask_decoder()

    assert mask_prediction.shape == (batch_size, 3, 256, 256)
    assert iou_prediction.shape == (batch_size, 3)
    assert torch.equal(mask_prediction[0], mask_prediction[1])
